from typing import List, Optional, Sequence, Tuple, Union, cast
from uuid import UUID
from overrides import overrides
from chromadb.config import System
from chromadb.db.base import NotFoundError, UniqueConstraintError
from chromadb.db.system import SysDB
from chromadb.proto.convert import (
    from_proto_collection,
    from_proto_segment,
    to_proto_update_metadata,
    to_proto_segment,
    to_proto_segment_scope,
)
from chromadb.proto.coordinator_pb2 import (
    CreateCollectionRequest,
    CreateSegmentRequest,
    DeleteCollectionRequest,
    DeleteSegmentRequest,
    GetCollectionsRequest,
    GetCollectionsResponse,
    GetSegmentsRequest,
    UpdateCollectionRequest,
    UpdateSegmentRequest,
)
from chromadb.proto.coordinator_pb2_grpc import SysDBStub
from chromadb.types import (
    Collection,
    Metadata,
    OptionalArgument,
    Segment,
    SegmentScope,
    Unspecified,
    UpdateMetadata,
)
from google.protobuf.empty_pb2 import Empty
import grpc


class GrpcSysDB(SysDB):
    """A gRPC implementation of the SysDB. In the distributed system, the SysDB is also
    called the 'Coordinator'. This implementation is used by Chroma frontend servers
    to call a remote SysDB (Coordinator) service.

    Each operation builds the corresponding request message and sends it through a
    ``SysDBStub`` bound to a channel that is opened in ``start()`` and closed in
    ``stop()``.
    """

    # Stub used to issue every RPC; created in start().
    _sys_db_stub: SysDBStub
    # Channel backing the stub; created in start() and closed in stop().
    _channel: grpc.Channel
    # Value of the "chroma_coordinator_host" setting (host part of the target).
    _coordinator_url: str
    # Value of the "chroma_server_grpc_port" setting (port part of the target).
    _coordinator_port: int

    def __init__(self, system: System):
        """Read the coordinator host and port from ``system.settings``.

        No connection is made here; the channel is opened in ``start()``.
        """
        self._coordinator_url = system.settings.require("chroma_coordinator_host")
        # TODO: break out coordinator_port into a separate setting?
        self._coordinator_port = system.settings.require("chroma_server_grpc_port")
        super().__init__(system)

    @overrides
    def start(self) -> None:
        """Open an insecure channel to ``host:port`` and create the stub on it."""
        # TODO: add retry policy here
        self._channel = grpc.insecure_channel(
            f"{self._coordinator_url}:{self._coordinator_port}"
        )
        self._sys_db_stub = SysDBStub(self._channel)  # type: ignore
        super().start()

    @overrides
    def stop(self) -> None:
        """Close the channel if one was opened, then run the base-class stop."""
        # `_channel` is only assigned in start(); guard so that stopping a
        # component that was never started does not raise AttributeError.
        channel = getattr(self, "_channel", None)
        if channel is not None:
            channel.close()
        super().stop()

    @overrides
    def reset_state(self) -> None:
        """Send ``ResetState`` to the coordinator, then run the base-class reset."""
        # NOTE(review): the response is not inspected here.
        self._sys_db_stub.ResetState(Empty())
        super().reset_state()

    @overrides
    def create_segment(self, segment: Segment) -> None:
        """Create ``segment`` on the coordinator.

        Raises:
            UniqueConstraintError: if the response status code is 409.
        """
        request = CreateSegmentRequest(segment=to_proto_segment(segment))
        response = self._sys_db_stub.CreateSegment(request)
        if response.status.code == 409:
            raise UniqueConstraintError()

    @overrides
    def delete_segment(self, id: UUID) -> None:
        """Delete the segment identified by ``id``.

        Raises:
            NotFoundError: if the response status code is 404.
        """
        request = DeleteSegmentRequest(id=id.hex)
        response = self._sys_db_stub.DeleteSegment(request)
        if response.status.code == 404:
            raise NotFoundError()

    @overrides
    def get_segments(
        self,
        id: Optional[UUID] = None,
        type: Optional[str] = None,
        scope: Optional[SegmentScope] = None,
        topic: Optional[str] = None,
        collection: Optional[UUID] = None,
    ) -> Sequence[Segment]:
        """Fetch segments from the coordinator.

        Every argument is optional; an argument left as ``None`` is passed to the
        request as ``None`` (presumably meaning "do not filter on this field" --
        confirm against the coordinator). UUIDs are sent as their hex string.

        Returns:
            The segments in the response, converted with ``from_proto_segment``.
        """
        request = GetSegmentsRequest(
            id=id.hex if id is not None else None,
            type=type,
            scope=to_proto_segment_scope(scope) if scope is not None else None,
            topic=topic,
            collection=collection.hex if collection is not None else None,
        )
        response = self._sys_db_stub.GetSegments(request)
        return [
            from_proto_segment(proto_segment) for proto_segment in response.segments
        ]

    @overrides
    def update_segment(
        self,
        id: UUID,
        topic: OptionalArgument[Optional[str]] = Unspecified(),
        collection: OptionalArgument[Optional[UUID]] = Unspecified(),
        metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
    ) -> None:
        """Update the segment identified by ``id``.

        For each of ``topic``, ``collection`` and ``metadata``:

        * left as ``Unspecified()``: the field is not set on the request;
        * ``None``: the field is cleared and the matching ``reset_*`` flag is set;
        * any other value: the value is sent (``collection`` as its hex string,
          ``metadata`` via ``to_proto_update_metadata``). An empty ``metadata``
          mapping is not sent.
        """
        unspecified = Unspecified()
        # Unwrap the sentinel: "not specified" and an explicit None both become
        # None here; they are told apart again below via the original arguments.
        write_topic = cast(Optional[str], topic) if topic != unspecified else None
        write_collection = (
            cast(Optional[UUID], collection) if collection != unspecified else None
        )
        write_metadata = (
            cast(Optional[UpdateMetadata], metadata)
            if metadata != unspecified
            else None
        )

        request = UpdateSegmentRequest(
            id=id.hex,
            topic=write_topic,
            collection=(
                write_collection.hex if write_collection is not None else None
            ),
            metadata=(
                to_proto_update_metadata(write_metadata) if write_metadata else None
            ),
        )

        # An explicit None asks for the stored value to be reset.
        if topic is None:
            request.ClearField("topic")
            request.reset_topic = True

        if collection is None:
            request.ClearField("collection")
            request.reset_collection = True

        if metadata is None:
            request.ClearField("metadata")
            request.reset_metadata = True

        # NOTE(review): unlike create/delete, the response status is discarded
        # here -- confirm whether a not-found status should raise NotFoundError.
        self._sys_db_stub.UpdateSegment(request)

    @overrides
    def create_collection(
        self,
        id: UUID,
        name: str,
        metadata: Optional[Metadata] = None,
        dimension: Optional[int] = None,
        get_or_create: bool = False,
    ) -> Tuple[Collection, bool]:
        """Create a collection on the coordinator.

        ``metadata`` is only sent when it is non-empty. ``get_or_create`` is
        forwarded to the coordinator unchanged.

        Returns:
            A tuple of the collection from the response (converted with
            ``from_proto_collection``) and the response's ``created`` flag.

        Raises:
            UniqueConstraintError: if the response status code is 409.
        """
        request = CreateCollectionRequest(
            id=id.hex,
            name=name,
            metadata=to_proto_update_metadata(metadata) if metadata else None,
            dimension=dimension,
            get_or_create=get_or_create,
        )
        response = self._sys_db_stub.CreateCollection(request)
        if response.status.code == 409:
            raise UniqueConstraintError()
        return from_proto_collection(response.collection), response.created

    @overrides
    def delete_collection(self, id: UUID) -> None:
        """Delete the collection identified by ``id``.

        Raises:
            NotFoundError: if the response status code is 404.
        """
        request = DeleteCollectionRequest(id=id.hex)
        response = self._sys_db_stub.DeleteCollection(request)
        if response.status.code == 404:
            raise NotFoundError()

    @overrides
    def get_collections(
        self,
        id: Optional[UUID] = None,
        topic: Optional[str] = None,
        name: Optional[str] = None,
    ) -> Sequence[Collection]:
        """Fetch collections from the coordinator.

        Every argument is optional; an argument left as ``None`` is passed to the
        request as ``None`` (presumably meaning "do not filter on this field" --
        confirm against the coordinator). ``id`` is sent as its hex string.

        Returns:
            The collections in the response, converted with
            ``from_proto_collection``.
        """
        request = GetCollectionsRequest(
            id=id.hex if id is not None else None,
            topic=topic,
            name=name,
        )
        response: GetCollectionsResponse = self._sys_db_stub.GetCollections(request)
        return [
            from_proto_collection(proto_collection)
            for proto_collection in response.collections
        ]

    @overrides
    def update_collection(
        self,
        id: UUID,
        topic: OptionalArgument[str] = Unspecified(),
        name: OptionalArgument[str] = Unspecified(),
        dimension: OptionalArgument[Optional[int]] = Unspecified(),
        metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
    ) -> None:
        """Update the collection identified by ``id``.

        Arguments left as ``Unspecified()`` are not set on the request. A
        specified ``topic``, ``name`` or ``dimension`` is sent as given; a
        non-empty ``metadata`` is sent via ``to_proto_update_metadata``. Passing
        ``metadata=None`` clears the field and sets ``reset_metadata``.
        """
        unspecified = Unspecified()
        write_topic = cast(str, topic) if topic != unspecified else None
        write_name = cast(str, name) if name != unspecified else None
        # NOTE(review): an explicit dimension=None is sent the same way as an
        # unspecified one; no reset flag is set for it here.
        write_dimension = (
            cast(Optional[int], dimension) if dimension != unspecified else None
        )
        write_metadata = (
            cast(Optional[UpdateMetadata], metadata)
            if metadata != unspecified
            else None
        )

        request = UpdateCollectionRequest(
            id=id.hex,
            topic=write_topic,
            name=write_name,
            dimension=write_dimension,
            metadata=(
                to_proto_update_metadata(write_metadata) if write_metadata else None
            ),
        )
        # An explicit None asks for the stored metadata to be reset.
        if metadata is None:
            request.ClearField("metadata")
            request.reset_metadata = True

        # NOTE(review): unlike create/delete, the response status is discarded
        # here -- confirm whether a not-found status should raise NotFoundError.
        self._sys_db_stub.UpdateCollection(request)

    def reset_and_wait_for_ready(self) -> None:
        """Send ``ResetState`` with ``wait_for_ready=True``.

        Unlike ``reset_state()``, this does not call the base-class
        ``reset_state()``.
        """
        self._sys_db_stub.ResetState(Empty(), wait_for_ready=True)
